The purpose of this notebook is to infer the rate at which confirmed cases of COVID-19 are growing (or were growing) in various countries.
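The notebook pulls data from the Johns Hopkins Data Repository of global Coronavirus COVID-19 cases and then, for each country, does the following: it computes the growth in confirmed cases over the most recent week, fits logistic and exponential curves to the cumulative case counts, and reports the doubling times implied by the recent growth and by each curve fit.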
We then repeat these steps for US states.
The notebook is updated approximately daily.
For a great primer on exponential and logistic growth, watch this video.
The growth rate (and the doubling time) changes with time. As the exponential curve eventually turns into a logistic curve, the growth rate will shrink to zero (and the doubling time will consequently increase). So it's not a good idea to extrapolate trends far into the future based on current growth rates or doubling times.
The confirmed cases reported by each country are not the number of infections in each country; they count only the people who have tested positive.
The doubling time calculated here measures the growth of cumulative confirmed cases, which is different from the growth of infections. For example, if a country suddenly ramps up testing, then the number of confirmed cases will rapidly rise, but infections may not be rising at the same rate.
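The doubling times inferred from the curve fits are not necessarily the current or most recent doubling times: the logistic fit reports the doubling time during the middle of the growth phase, and the exponential fit reflects growth over the whole period of data. The recent doubling time is calculated separately, from the most recent week of data.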
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
from termcolor import colored, cprint
import plotly.graph_objects as go
#import plotly.offline as offline
#offline.init_notebook_mode(connected=True)
def logistic(t, a, b, c, d):
    """Evaluate a four-parameter logistic curve at ``t``.

    The curve rises from the lower asymptote ``c`` (as t -> -inf, for b > 0)
    to the upper asymptote ``d`` (as t -> +inf). ``b`` sets the steepness and
    ``a`` shifts the midpoint. ``t`` may be a scalar or a NumPy array.
    """
    lower, upper = c, d
    denominator = 1 + a * np.exp(-b * t)
    return lower + (upper - lower) / denominator
def exponential(t, a, b, c):
    """Evaluate the offset exponential ``a * exp(b * t) + c`` at ``t``.

    ``t`` may be a scalar or a NumPy array.
    """
    growth = np.exp(b * t)
    return c + a * growth
def _r_squared(y, yfit):
    """Return the coefficient of determination (R^2) of ``yfit`` against ``y``.

    The result is not finite (and NumPy emits a RuntimeWarning) when ``y`` is
    constant, because the total sum of squares is then zero.
    """
    ss_res = np.sum((y - yfit) ** 2)
    ss_tot = np.sum((y - np.mean(y)) ** 2)
    return 1 - (ss_res / ss_tot)


def _report_fit_exception(fitname, ex):
    """Print, in red, which curve fit failed plus the exception type and message."""
    cprint('\nException in ' + fitname + ' process ', 'red')
    cprint(type(ex), 'red')
    cprint(ex, 'red')


def _recent_doubling_time(co, y):
    """Print growth over the most recent week and return the implied doubling time.

    ``co`` is the per-day DataFrame whose index holds the date labels, and ``y``
    the matching array of cumulative cases. Returns NaN when there are fewer
    than eight data points or when cases did not increase over the week.
    """
    # y[-8] is the value seven days before y[-1], so eight points are required
    # (the previous ">= 7" check raised IndexError with exactly seven points).
    if len(y) < 8:
        return float('NaN')
    current = y[-1]
    lastweek = y[-8]
    if current <= lastweek:
        return float('NaN')
    print('\n** Based on Most Recent Week of Data **\n')
    print('\tConfirmed cases on', co.index[-1], '\t', current)
    print('\tConfirmed cases on', co.index[-8], '\t', lastweek)
    ratio = current / lastweek
    print('\tRatio:', round(ratio, 2))
    print('\tWeekly increase:', round(100 * (ratio - 1), 1), '%')
    dailypercentchange = round(100 * (pow(ratio, 1 / 7) - 1), 1)
    print('\tDaily increase:', dailypercentchange, '% per day')
    # Cases multiplied by `ratio` over 7 days, so doubling takes 7*ln2/ln(ratio) days.
    recentdbltime = round(7 * np.log(2) / np.log(ratio), 1)
    print('\tDoubling Time (represents recent growth):', recentdbltime, 'days')
    return recentdbltime


def plotCases(dataframe, column, country, maxfev=100000, mostrecentdate=None):
    """Plot cumulative confirmed cases for one region, fit growth curves, and
    report doubling times.

    Rows of ``dataframe`` whose ``column`` equals ``country`` are summed per
    day. The first four columns are treated as metadata (Province/State,
    Country/Region, Lat, Long); every later column is a daily cumulative
    count. Days with zero cases are dropped before fitting.

    Args:
        dataframe: wide time-series table with one column per date.
        column: column to filter on (e.g. 'Province/State').
        country: value of ``column`` selecting the rows to aggregate.
        maxfev: maximum function evaluations passed to ``curve_fit``.
        mostrecentdate: label shown in the plot title; defaults to the last
            column header of ``dataframe`` (the latest date in the data).

    Returns:
        ``[doubling_time, ci95, recent_doubling_time]``. The first two come
        from whichever accepted fit (R^2 > 0.95) has the higher R^2 rounded to
        two decimals (ties go to the exponential fit), or are NaN when neither
        fit was accepted. ``recent_doubling_time`` comes from the last week of
        data and is NaN when unavailable.
    """
    if mostrecentdate is None:
        mostrecentdate = str(dataframe.columns[-1])

    co = dataframe[dataframe[column] == country].iloc[:, 4:].T.sum(axis=1)
    co = pd.DataFrame(co)
    co.columns = ['Cases']
    co = co.loc[co['Cases'] > 0]
    y = np.array(co['Cases'])
    x = np.arange(y.size)

    recentdbltime = _recent_doubling_time(co, y)

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=x, y=y, mode='markers', name='Original Data'))

    logisticworked = False
    exponentialworked = False

    # Fit failures are reported and skipped rather than aborting the whole run.
    try:
        lpopt, lpcov = curve_fit(logistic, x, y, maxfev=maxfev)
        lerror = np.sqrt(np.diag(lpcov))
        # For the logistic curve at half maximum, the relative growth rate is
        # b/2, so doubling time = ln(2) / (b/2).
        ldoubletime = np.log(2) / (lpopt[1] / 2)
        # 95% confidence half-width (1.96 standard errors): doubling time is
        # proportional to 1/b, so it inherits b's relative error.
        ldoubletimeerror = 1.96 * ldoubletime * np.abs(lerror[1] / lpopt[1])
        logisticr2 = _r_squared(y, logistic(x, *lpopt))
        if logisticr2 > 0.95:
            fig.add_trace(go.Scatter(x=x, y=logistic(x, *lpopt), mode='lines', line=dict(dash='dot'), name="Logistic Curve Fit"))
            print('\n** Based on Logistic Fit**\n')
            print('\tR^2:', logisticr2)
            print('\tDoubling Time (during middle of growth): ', round(ldoubletime, 2), '(±', round(ldoubletimeerror, 2), ') days')
            print("\tparam: ", lpopt)
            logisticworked = True
        else:
            print("\n logistic R^2 ", logisticr2)
    except Exception as ex:
        _report_fit_exception('logistic', ex)

    try:
        epopt, epcov = curve_fit(exponential, x, y, bounds=([0, 0, -100], [100, 0.9, 100]), maxfev=maxfev)
        eerror = np.sqrt(np.diag(epcov))
        # For the exponential curve, the growth rate is b, so doubling time = ln(2) / b.
        edoubletime = np.log(2) / epopt[1]
        # 95% confidence half-width (1.96 standard errors), propagated from b.
        edoubletimeerror = 1.96 * edoubletime * np.abs(eerror[1] / epopt[1])
        expr2 = _r_squared(y, exponential(x, *epopt))
        if expr2 > 0.95:
            fig.add_trace(go.Scatter(x=x, y=exponential(x, *epopt), mode='lines', line=dict(dash='dot'), name="Exponential Curve Fit"))
            print('\n** Based on Exponential Fit **\n')
            print('\tR^2:', expr2)
            print('\tDoubling Time (represents overall growth): ', round(edoubletime, 2), '(±', round(edoubletimeerror, 2), ') days')
            print("\tparam: ", epopt)
            exponentialworked = True
        else:
            print("\n exponential R^2 ", expr2)
    except Exception as ex:
        _report_fit_exception('exponential', ex)

    fig.update_layout(title=country + ' Cumulative COVID-19 Cases. (Updated on ' + mostrecentdate + ')',
                      xaxis_title='Days',
                      yaxis_title='Total Cases',
                      width=900, height=700, autosize=False)
    fig.show()

    # Prefer the logistic fit only when it strictly beats an accepted exponential fit.
    if logisticworked and (not exponentialworked or round(logisticr2, 2) > round(expr2, 2)):
        return [ldoubletime, ldoubletimeerror, recentdbltime]
    if exponentialworked:
        return [edoubletime, edoubletimeerror, recentdbltime]
    return [float('NaN'), float('NaN'), recentdbltime]
# Folder of Johns Hopkins CSSE time-series CSVs; the US file is read from here.
datadir = 'https://github.com/CSSEGISandData/COVID-19/raw/master/csse_covid_19_data/csse_covid_19_time_series/'
# `uscases` and `df` name the same DataFrame, so the in-place rename below
# is visible through both names.
uscases = df = pd.read_csv(datadir + 'time_series_covid19_confirmed_US.csv')
# Alternative input (not used):
#df = pd.read_csv( datadir + 'time_series_covid19_deaths_US.csv')

# The US file uses underscore-style headers; rename them to the
# slash-style names used by the global file layout.
uscases.rename(
    columns=dict(
        Country_Region='Country/Region',
        Province_State='Province/State',
        Long_='Long',
    ),
    inplace=True,
)
# US state/territory abbreviation -> name lookup, adapted from
# https://code.activestate.com/recipes/577305-python-dictionary-of-us-states-and-territories/
# with DC added. Besides the 50 states it contains territories (AS, GU, MP,
# PR, VI), a non-geographic 'NA': 'National' entry, and 'D.C.' as an alias
# key for 'DC'.
# NOTE(review): nothing in this section reads this dict, and the name `states`
# is later rebound to a list of state names — confirm whether it is needed.
states = {
    'AK': 'Alaska',
    'AL': 'Alabama',
    'AR': 'Arkansas',
    'AS': 'American Samoa',
    'AZ': 'Arizona',
    'CA': 'California',
    'CO': 'Colorado',
    'CT': 'Connecticut',
    'DC': 'District of Columbia',
    'DE': 'Delaware',
    'FL': 'Florida',
    'GA': 'Georgia',
    'GU': 'Guam',
    'HI': 'Hawaii',
    'IA': 'Iowa',
    'ID': 'Idaho',
    'IL': 'Illinois',
    'IN': 'Indiana',
    'KS': 'Kansas',
    'KY': 'Kentucky',
    'LA': 'Louisiana',
    'MA': 'Massachusetts',
    'MD': 'Maryland',
    'ME': 'Maine',
    'MI': 'Michigan',
    'MN': 'Minnesota',
    'MO': 'Missouri',
    'MP': 'Northern Mariana Islands',
    'MS': 'Mississippi',
    'MT': 'Montana',
    'NA': 'National',
    'NC': 'North Carolina',
    'ND': 'North Dakota',
    'NE': 'Nebraska',
    'NH': 'New Hampshire',
    'NJ': 'New Jersey',
    'NM': 'New Mexico',
    'NV': 'Nevada',
    'NY': 'New York',
    'OH': 'Ohio',
    'OK': 'Oklahoma',
    'OR': 'Oregon',
    'PA': 'Pennsylvania',
    'PR': 'Puerto Rico',
    'RI': 'Rhode Island',
    'SC': 'South Carolina',
    'SD': 'South Dakota',
    'TN': 'Tennessee',
    'TX': 'Texas',
    'UT': 'Utah',
    'VA': 'Virginia',
    'VI': 'Virgin Islands',
    'VT': 'Vermont',
    'WA': 'Washington',
    'WI': 'Wisconsin',
    'WV': 'West Virginia',
    'WY': 'Wyoming',
    'D.C.': 'District of Columbia',
}
# Column layouts of the two JHU time-series files:
#   global: Province/State, Country/Region, Lat, Long, <one column per date>
#   US:     UID, iso2, iso3, code3, FIPS, Admin2, Province_State,
#           Country_Region, Lat, Long_, Combined_Key, <one column per date>
# Once the US headers are renamed and the US-only identifier columns dropped,
# the US table has the global layout:
#   Province/State Country/Region Lat Long <dates...>
uscases = uscases.drop(
    columns=['UID', 'iso2', 'iso3', 'code3', 'FIPS', 'Admin2', 'Combined_Key']
)

# Column 0 is Province/State and column -1 is the latest date: total the
# latest cumulative count for each state.
usstatesummary = (
    uscases.iloc[:, [0, -1]]
    .groupby('Province/State')
    .sum()
)
mostrecentdate = usstatesummary.columns[0]
usstatesummary = usstatesummary.sort_values(by=mostrecentdate, ascending=False)
usstatesummary = usstatesummary.loc[usstatesummary[mostrecentdate] > 0]

print('\nNumber of confirmed US COVID-19 cases by state as of', mostrecentdate)
# Bare expression: renders the table only as a notebook cell's final output.
usstatesummary

# Only states with at least 100 confirmed cases are analysed further.
topusstates = usstatesummary.loc[usstatesummary[mostrecentdate] >= 100]
print(topusstates)
print('\n')
# Run plotCases for every state with at least 100 confirmed cases and collect
# its inferred doubling time, 95% CI half-width, and recent doubling time.
inferreddoublingtime = []
recentdoublingtime = []
errors = []
# NOTE: this rebinds `states` (previously the abbreviation lookup dict) to the
# list of state names processed below.
states = []
for state in topusstates.index.values:
    print('US state: ', state)
    a = plotCases(uscases, 'Province/State', state)
    if a:
        states.append(state)
        inferreddoublingtime.append(a[0])
        errors.append(a[1])
        recentdoublingtime.append(a[2])
    print('\n')

d = {'States': states, 'Inferred Doubling Time': inferreddoublingtime, '95%CI': errors, 'Recent Doubling Time': recentdoublingtime}
print('\nInferred Doubling Times are inferred using curve fits.')
print('Recent Doubling Times are calculated using the most recent week of data.')
print('Shorter doubling time = faster growth, longer doubling time = slower growth.')
print('\n')
# Index by state so every row of the printed table is labelled; selecting
# columns [3, 1, 2] by position dropped the 'States' column entirely.
print(pd.DataFrame(data=d).set_index('States')[['Recent Doubling Time', 'Inferred Doubling Time', '95%CI']].round(1))
print('\n')

# Bar chart of inferred doubling times with 95% CI error bars. Rows with
# NaN or implausibly long (>= 100 day) doubling times are left out.
dt = pd.DataFrame(data=d)
dt = dt[dt['Inferred Doubling Time'] < 100]
dt.plot.bar(x='States', y='Inferred Doubling Time', yerr='95%CI', legend=False, figsize=(10, 5), fontsize="x-large")
# Reference lines at 1, 3 and 5 days.
plt.axhline(y=1, linestyle='--')
plt.axhline(y=3, linestyle='--')
plt.axhline(y=5, linestyle='--')
plt.ylabel('Inferred Doubling Time (Days)', fontsize="x-large")
plt.xlabel('US States', fontsize="x-large")
plt.title('Inferred Doubling Time of Cumulative COVID-19 Cases in US States. Last update: ' + mostrecentdate, fontsize="x-large")
plt.show()

# Side-by-side bars of inferred vs. recent doubling times. Error bars are
# drawn only for the inferred values; the recent values have no CI (NaN).
err = pd.DataFrame([errors, [float('NaN') for e in errors]]).T
err.index = states
err.columns = ['Inferred Doubling Time', 'Recent Doubling Time']
print('\n')
dt = pd.DataFrame({'Inferred Doubling Time': inferreddoublingtime, 'Recent Doubling Time': recentdoublingtime}, index=states)
dt = dt[dt['Recent Doubling Time'] < 100]
dt.plot.bar(yerr=err, figsize=(10, 5), fontsize="x-large")
plt.ylabel('Doubling Time (Days)', fontsize="x-large")
plt.xlabel('US States', fontsize="x-large")
plt.axhline(y=1, linestyle='--')
plt.axhline(y=3, linestyle='--')
plt.axhline(y=5, linestyle='--')
plt.title('Doubling Time of Cumulative COVID-19 Cases in US States. Last update: ' + mostrecentdate, fontsize="x-large")
plt.show()